Fixes for scaled initializers, function inlining, model selection, and added Conv1d support. #1
sidcraftscode wants to merge 1 commit into main
…tion, and add Conv1d support
Pull request overview
This pull request addresses several fixes and enhancements to the Delta compiler:
Changes:
- Fixed model selection logic to prefer the 'forward' module when available
- Added support for Conv1d layers in the FX lowering backend
- Enhanced initializer handling to support scaled initializers (e.g., randn(3, 2) * 0.01)
- Removed debug print statements from the SIR builder
- Updated constraint property construction to use role-based information
- Modified block result inference by removing ReturnStmt from the exclusion list
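As a rough illustration of the first change in the list above, preferring a module named 'forward' could look like the sketch below. The function name `select_module`, the dict-of-modules input, and the fallback to the first module in insertion order are all assumptions made for illustration; the actual logic in delta/run.py is not shown in this review.

```python
from typing import Any, Dict, Tuple


def select_module(modules: Dict[str, Any]) -> Tuple[str, Any]:
    """Pick the module to run, preferring one named 'forward'.

    Hypothetical helper: the fallback (first module in insertion order)
    is an assumption, not confirmed behavior of the Delta compiler.
    """
    if not modules:
        raise ValueError("no modules available to select from")
    if "forward" in modules:
        return "forward", modules["forward"]
    # Fall back to the first module that was registered.
    first_name = next(iter(modules))
    return first_name, modules[first_name]
```

With this shape, `select_module({"init": a, "forward": b})` would return the 'forward' entry regardless of its position in the dict.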
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| delta/run.py | Added BinaryOp import and implemented scaled initializer shape extraction; improved module selection logic |
| delta/ir/sir_builder.py | Removed debug prints, updated constraint properties with ConstraintType and role info, modified block result inference |
| delta/backend/fx_lowering.py | Added Conv1d layer support alongside existing Conv2d |
| props = SIRProperty( | ||
| dtype=FloatType(), | ||
| requires_grad=expr_node.requires_grad, | ||
| dtype=ConstraintType(), |
The type ConstraintType() does not exist in the codebase. This will cause a NameError at runtime when building constraint statements. The type should be FloatType() instead, as constraints evaluate to scalar penalty terms that are added to the objective. The original code used FloatType() for the dtype of constraints.
| dtype=ConstraintType(), | |
| dtype=FloatType(), |
| result = self._build_expr(block.result) | ||
| nodes.append(result) | ||
| elif nodes and not isinstance(block.statements[-1], (ReturnStmt, ParamDecl, ObsDecl)): | ||
| elif nodes and not isinstance(block.statements[-1], (ParamDecl, ObsDecl)): |
The removal of ReturnStmt from this check may cause incorrect block result inference. When a block ends with a return statement, the return statement itself should not be used as the block's implicit result value. Consider whether this change is intentional and whether it aligns with the semantics of how return statements should be handled in blocks.
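To make the concern concrete, here is a minimal sketch of the check as it stood before this change, with ReturnStmt still in the exclusion tuple. The dataclasses are hypothetical stand-ins for the real AST nodes, and `implicit_result_index` is an illustrative helper, not a function from delta/ir/sir_builder.py.

```python
from dataclasses import dataclass
from typing import Any, List, Optional


# Hypothetical stand-ins for the AST statement nodes.
@dataclass
class ExprStmt:
    expr: Any


@dataclass
class ReturnStmt:
    value: Any


@dataclass
class ParamDecl:
    name: str


@dataclass
class ObsDecl:
    name: str


# Statement kinds that should never be treated as a block's implicit result.
NON_RESULT_STATEMENTS = (ReturnStmt, ParamDecl, ObsDecl)


def implicit_result_index(statements: List[Any]) -> Optional[int]:
    """Return the index of the statement used as the implicit result, if any."""
    if not statements:
        return None
    if isinstance(statements[-1], NON_RESULT_STATEMENTS):
        return None
    return len(statements) - 1
```

Dropping ReturnStmt from `NON_RESULT_STATEMENTS` would make a trailing return statement eligible as the implicit result, which is the behavior the comment asks the author to confirm.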
| def _extract_shape_from_initializer(self, initializer_expr: Any) -> Optional[tuple]: | ||
| """Extract shape from an initializer expression like randn(3, 2).""" | ||
| from delta.frontend.ast import Call, Identifier, Literal, Tensor | ||
| from delta.frontend.ast import Call, Identifier, Literal, Tensor, BinaryOp |
The BinaryOp import was added to the top-level imports but is also imported again within the _extract_shape_from_initializer method at line 157. Consider removing the redundant local import since BinaryOp is now available from the module-level import at line 30.
| from delta.frontend.ast import Call, Identifier, Literal, Tensor, BinaryOp | |
| from delta.frontend.ast import Tensor |
| # Handle scaling: randn(...) * 0.01 | ||
| if isinstance(initializer_expr, BinaryOp): | ||
| # Recursively check operands | ||
| shape = self._extract_shape_from_initializer(initializer_expr.left) | ||
| if shape: return shape | ||
| return self._extract_shape_from_initializer(initializer_expr.right) |
The handling of BinaryOp expressions in _extract_shape_from_initializer is incomplete. While it recursively extracts shapes from binary operations like randn(3, 2) * 0.01, the corresponding _eval_initializer method (lines 90-137) doesn't handle BinaryOp expressions. This means scaled initializers will have their shapes extracted correctly, but the actual initialization won't apply the scaling factor. Consider also updating _eval_initializer to handle BinaryOp expressions to fully support scaled initializers.
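One possible way to close this gap is sketched below: an evaluator that recurses through BinaryOp nodes and applies the operator, so randn(3, 2) * 0.01 yields scaled values as well as the right shape. The AST dataclasses, the flat-list tensor representation, and the function names are assumptions chosen to keep the example self-contained; the real `_eval_initializer` presumably produces tensors and may be structured differently.

```python
import operator
import random
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple, Union


# Hypothetical stand-ins for nodes in delta.frontend.ast.
@dataclass
class Literal:
    value: float


@dataclass
class Call:
    func: str
    args: List[Any]


@dataclass
class BinaryOp:
    op: str
    left: Any
    right: Any


_OPS = {"*": operator.mul, "/": operator.truediv,
        "+": operator.add, "-": operator.sub}
Value = Union[float, List[float]]


def extract_shape(expr: Any) -> Optional[Tuple[int, ...]]:
    """Return the shape of the first tensor-producing call, else None."""
    if isinstance(expr, Call) and expr.func in ("randn", "zeros", "ones"):
        return tuple(int(arg.value) for arg in expr.args)
    if isinstance(expr, BinaryOp):
        shape = extract_shape(expr.left)
        if shape is not None:
            return shape
        return extract_shape(expr.right)
    return None


def apply_op(op: str, left: Value, right: Value) -> Value:
    """Apply op elementwise, broadcasting a scalar against a flat list."""
    fn = _OPS[op]
    if isinstance(left, list) and isinstance(right, list):
        return [fn(a, b) for a, b in zip(left, right)]
    if isinstance(left, list):
        return [fn(a, right) for a in left]
    if isinstance(right, list):
        return [fn(left, b) for b in right]
    return fn(left, right)


def eval_initializer(expr: Any, rng: random.Random) -> Value:
    """Evaluate to a scalar or a flat list of values (row-major order)."""
    if isinstance(expr, Literal):
        return float(expr.value)
    if isinstance(expr, Call):
        shape = extract_shape(expr)
        if shape is None:
            raise ValueError(f"unsupported initializer call: {expr.func}")
        count = 1
        for dim in shape:
            count *= dim
        if expr.func == "randn":
            return [rng.gauss(0.0, 1.0) for _ in range(count)]
        return [1.0 if expr.func == "ones" else 0.0] * count
    if isinstance(expr, BinaryOp):
        # Evaluate both operands, then combine: this is the step missing
        # when only the shape extraction understands BinaryOp.
        left = eval_initializer(expr.left, rng)
        right = eval_initializer(expr.right, rng)
        return apply_op(expr.op, left, right)
    raise TypeError(f"unsupported initializer node: {type(expr).__name__}")
```

Keeping shape extraction and evaluation as two functions that both handle BinaryOp mirrors the split the comment describes; a single pass returning both shape and values would be an alternative design.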
| # Ideally this metadata would be extracted during compilation, but | ||
| # for now we use the AST directly since CompileResult exposes it. | ||
| from delta.frontend.ast import ParamDecl, Call, Identifier, Literal | ||
| from delta.frontend.ast import ParamDecl, Call, Identifier, Literal, BinaryOp |
Import of 'BinaryOp' is not used.
| from delta.frontend.ast import ParamDecl, Call, Identifier, Literal, BinaryOp | |
| from delta.frontend.ast import ParamDecl, Call, Identifier, Literal |
No description provided.